{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "from sklearn.datasets import load_digits\n",
    "from sklearn import datasets\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "from pytorch_model_summary import summary"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**DISCLAIMER**\n",
    "\n",
    "The presented code is not optimized; it serves an educational purpose. It is written for CPU, uses only fully-connected networks, and relies on an extremely simplistic dataset. However, it contains all the components needed to understand how HybridIDF works, and it should be rather easy to extend it to more sophisticated models. The code can be run on almost any laptop/PC, and it takes a couple of minutes at most to get the result."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this example, we go wild and use a dataset that is simpler than MNIST! We use a scikit-learn dataset called Digits. It consists of ~1800 images of size 8x8, and each pixel can take values in $\\{0, 1, \\ldots, 16\\}$.\n",
    "\n",
    "The goal of using this dataset is that everyone can run it on a laptop, without a GPU or any other special hardware."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Digits(Dataset):\n",
    "    \"\"\"Position-based train/val/test splits of the scikit-learn Digits dataset.\n",
    "\n",
    "    The samples returned by ``load_digits`` are split by index: the first\n",
    "    1000 form the training set, the next 350 the validation set and the\n",
    "    remaining ones the test set. Features are stored as float32.\n",
    "\n",
    "    Args:\n",
    "        mode: 'train' or 'val' select the corresponding split; any other\n",
    "            value selects the test split.\n",
    "        transforms: optional callable applied to the feature vector of a\n",
    "            sample in ``__getitem__``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, mode='train', transforms=None):\n",
    "        if mode == 'train':\n",
    "            split = slice(None, 1000)\n",
    "        elif mode == 'val':\n",
    "            split = slice(1000, 1350)\n",
    "        else:\n",
    "            # Everything that is neither 'train' nor 'val' maps to the test split.\n",
    "            split = slice(1350, None)\n",
    "\n",
    "        bunch = load_digits()\n",
    "        self.data = bunch.data[split].astype(np.float32)\n",
    "        self.targets = bunch.target[split]\n",
    "        self.transforms = transforms\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Number of samples in the selected split.\"\"\"\n",
    "        return len(self.data)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"Return the ``(features, label)`` pair stored at position ``idx``.\"\"\"\n",
    "        features = self.data[idx]\n",
    "        label = self.targets[idx]\n",
    "        if self.transforms:\n",
    "            features = self.transforms(features)\n",
    "        return (features, label)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Distributions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Chakraborty & Chakravarty, \"A new discrete probability distribution with integer support on (−∞, ∞)\",\n",
    "#  Communications in Statistics - Theory and Methods, 45:2, 492-505, DOI: 10.1080/03610926.2013.830743\n",
    "\n",
    "def log_min_exp(a, b, epsilon=1e-8):\n",
    "    \"\"\"\n",
    "    Source: https://github.com/jornpeters/integer_discrete_flows\n",
    "    Computes the log of exp(a) - exp(b) in a (more) numerically stable fashion.\n",
    "    Using:\n",
    "     log(exp(a) - exp(b))\n",
    "     c + log(exp(a-c) - exp(b-c))\n",
    "     a + log(1 - exp(b-a))\n",
    "    And note that we assume b < a always.\n",
    "    \"\"\"\n",
    "    y = a + torch.log(1 - torch.exp(b - a) + epsilon)\n",
    "\n",
    "    return y\n",
    "\n",
    "def log_integer_probability(x, mean, logscale):\n",
    "    \"\"\"Log-probability of integer `x` under a discretized logistic distribution.\n",
    "\n",
    "    With scale = exp(logscale), the probability of x is\n",
    "        sigmoid((x + 0.5 - mean) / scale) - sigmoid((x - 0.5 - mean) / scale),\n",
    "    and its logarithm is computed from the two log-sigmoids with log_min_exp.\n",
    "    \"\"\"\n",
    "    scale = logscale.exp()\n",
    "\n",
    "    # Log-CDF of the logistic at the upper and the lower edge of the unit bin around x.\n",
    "    log_cdf_upper = F.logsigmoid((x + 0.5 - mean) / scale)\n",
    "    log_cdf_lower = F.logsigmoid((x - 0.5 - mean) / scale)\n",
    "\n",
    "    return log_min_exp(log_cdf_upper, log_cdf_lower)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### HybridIDF code"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Please see the blogpost for details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Source: https://github.com/jornpeters/integer_discrete_flows\n",
    "class RoundStraightThrough(torch.autograd.Function):\n",
    "    \n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "    @staticmethod\n",
    "    def forward(ctx, input):\n",
    "        rounded = torch.round(input, out=None)\n",
    "        return rounded\n",
    "\n",
    "    @staticmethod\n",
    "    def backward(ctx, grad_output):\n",
    "        grad_input = grad_output.clone()\n",
    "        return grad_input"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class HybridIDF(nn.Module):\n",
    "    def __init__(self, netts, classnet, num_flows, alpha=1., D=2):\n",
    "        super(HybridIDF, self).__init__()\n",
    "\n",
    "        print('HybridIDF by JT.')\n",
    "        \n",
    "        if len(netts) == 1:\n",
    "            self.t = torch.nn.ModuleList([netts[0]() for _ in range(num_flows)])\n",
    "            self.idf_git = 1\n",
    "            self.beta = nn.Parameter(torch.zeros(len(self.t)))\n",
    "        \n",
    "        elif len(netts) == 4:\n",
    "            self.t_a = torch.nn.ModuleList([netts[0]() for _ in range(num_flows)])\n",
    "            self.t_b = torch.nn.ModuleList([netts[1]() for _ in range(num_flows)])\n",
    "            self.t_c = torch.nn.ModuleList([netts[2]() for _ in range(num_flows)])\n",
    "            self.t_d = torch.nn.ModuleList([netts[3]() for _ in range(num_flows)])\n",
    "            self.idf_git = 4\n",
    "            self.beta = nn.Parameter(torch.zeros(len(self.t_a)))\n",
    "        \n",
    "        else:\n",
    "            raise ValueError('You can provide either 1 or 4 translation nets.')\n",
    "        \n",
    "        self.classnet = classnet\n",
    "        \n",
    "        self.num_flows = num_flows\n",
    "\n",
    "        self.round = RoundStraightThrough.apply\n",
    "        \n",
    "        self.mean = nn.Parameter(torch.zeros(1, D))\n",
    "        self.logscale = nn.Parameter(torch.ones(1, D))\n",
    "\n",
    "        self.D = D\n",
    "        \n",
    "        self.alpha = alpha\n",
    "        \n",
    "        self.nll = nn.NLLLoss(reduction='none') #it requires log-softmax as input!!\n",
    "\n",
    "    def coupling(self, x, index, forward=True):\n",
    "        \n",
    "        if self.idf_git == 1:\n",
    "            (xa, xb) = torch.chunk(x, 2, 1)\n",
    "            \n",
    "            if forward:\n",
    "                yb = xb + self.beta[index] * self.round(self.t[index](xa))\n",
    "            else:\n",
    "                yb = xb - self.beta[index] * self.round(self.t[index](xa))\n",
    "            \n",
    "            return torch.cat((xa, yb), 1)\n",
    "        \n",
    "        elif self.idf_git == 4:\n",
    "            (xa, xb, xc, xd) = torch.chunk(x, 4, 1)\n",
    "            \n",
    "            if forward:\n",
    "                ya = xa + self.beta[index] * self.round(self.t_a[index](torch.cat((xb, xc, xd), 1)))\n",
    "                yb = xb + self.beta[index] * self.round(self.t_b[index](torch.cat((ya, xc, xd), 1)))\n",
    "                yc = xc + self.beta[index] * self.round(self.t_c[index](torch.cat((ya, yb, xd), 1)))\n",
    "                yd = xd + self.beta[index] * self.round(self.t_d[index](torch.cat((ya, yb, yc), 1)))\n",
    "            else:\n",
    "                yd = xd - self.beta[index] * self.round(self.t_d[index](torch.cat((xa, xb, xc), 1)))\n",
    "                yc = xc - self.beta[index] * self.round(self.t_c[index](torch.cat((xa, xb, yd), 1)))\n",
    "                yb = xb - self.beta[index] * self.round(self.t_b[index](torch.cat((xa, yc, yd), 1)))\n",
    "                ya = xa - self.beta[index] * self.round(self.t_a[index](torch.cat((yb, yc, yd), 1)))\n",
    "            \n",
    "            return torch.cat((ya, yb, yc, yd), 1)\n",
    "\n",
    "    def permute(self, x):\n",
    "        return x.flip(1)\n",
    "\n",
    "    def f(self, x):\n",
    "        z = x\n",
    "        for i in range(self.num_flows):\n",
    "            z = self.coupling(z, i, forward=True)\n",
    "            z = self.permute(z)\n",
    "\n",
    "        return z\n",
    "\n",
    "    def f_inv(self, z):\n",
    "        x = z\n",
    "        for i in reversed(range(self.num_flows)):\n",
    "            x = self.permute(x)\n",
    "            x = self.coupling(x, i, forward=False)\n",
    "\n",
    "        return x\n",
    "    \n",
    "    def classify(self, x):\n",
    "        z = self.f(x)\n",
    "        y_pred = self.classnet(z) #output: probabilities (i.e., softmax)\n",
    "        return torch.argmax(y_pred, dim=1)\n",
    "    \n",
    "    def class_loss(self, x, y):\n",
    "        z = self.f(x)\n",
    "        y_pred = self.classnet(z) #output: probabilities (i.e., softmax)\n",
    "        return self.nll(torch.log(y_pred), y)\n",
    "\n",
    "    def sample(self, batchSize):\n",
    "        # sample z:\n",
    "        z = self.prior_sample(batchSize=batchSize, D=self.D)\n",
    "        # x = f^-1(z)\n",
    "        x = self.f_inv(z)\n",
    "        return x.view(batchSize, 1, self.D)\n",
    "\n",
    "    def log_prior(self, x):\n",
    "        log_p = log_integer_probability(x, self.mean, self.logscale)\n",
    "        return log_p.sum(1)\n",
    "\n",
    "    def prior_sample(self, batchSize, D=2):\n",
    "        # Sample from logistic\n",
    "        y = torch.rand(batchSize, self.D)\n",
    "        x = torch.exp(self.logscale) * torch.log(y / (1. - y)) + self.mean\n",
    "        # And then round it to an integer.\n",
    "        return torch.round(x)\n",
    "    \n",
    "    def forward(self, x, y, reduction='avg'):\n",
    "        z = self.f(x)\n",
    "        y_pred = self.classnet(z) #output: probabilities (i.e., softmax)\n",
    "        \n",
    "        idf_loss = -self.log_prior(z)\n",
    "        class_loss = self.nll(torch.log(y_pred), y) #remember to use logarithm on top of softmax!\n",
    "        \n",
    "        if reduction == 'sum':\n",
    "            return (class_loss + self.alpha * idf_loss).sum()\n",
    "        else:\n",
    "            return (class_loss + self.alpha * idf_loss).mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Auxiliary functions: training, evaluation, plotting"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It's rather self-explanatory, isn't it?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluation(test_loader, name=None, model_best=None, epoch=None):\n",
    "    # EVALUATION\n",
    "    if model_best is None:\n",
    "        # load best performing model\n",
    "        model_best = torch.load(name + '.model')\n",
    "\n",
    "    model_best.eval()\n",
    "    loss = 0.\n",
    "    loss_error = 0.\n",
    "    loss_gen = 0.\n",
    "    N = 0.\n",
    "    for indx_batch, (test_batch, test_targets) in enumerate(test_loader):\n",
    "        # hybrid loss\n",
    "        loss_t = model_best.forward(test_batch, test_targets, reduction='sum')\n",
    "        loss = loss + loss_t.item()\n",
    "        # classification error\n",
    "        y_pred = model_best.classify(test_batch)\n",
    "        e = 1.*(y_pred == test_targets)\n",
    "        loss_error = loss_error + (1. - e).sum().item()\n",
    "        # generative nll\n",
    "        loss_i = (loss_t.item() - model_best.class_loss(test_batch, test_targets).sum().item()) / model_best.alpha\n",
    "        loss_gen = loss_gen + loss_i\n",
    "        # the number of examples\n",
    "        N = N + test_batch.shape[0]\n",
    "    loss = loss / N\n",
    "    loss_error = loss_error / N\n",
    "    loss_gen = loss_gen / N\n",
    "\n",
    "    if epoch is None:\n",
    "        print(f'FINAL PERFORMANCE: nll={loss}, ce={loss_error}, gen_nll={loss_gen}')\n",
    "    else:\n",
    "        print(f'Epoch: {epoch}, val nll={loss}, val ce={loss_error}, val gen_nll={loss_gen}')\n",
    "\n",
    "    return loss, loss_error, loss_gen\n",
    "\n",
    "\n",
    "def samples_real(name, test_loader):\n",
    "    \"\"\"Save a 4x4 grid of real images to ``name + '_real_images.pdf'``.\n",
    "\n",
    "    The images are the first (up to) 16 samples of the first batch yielded by\n",
    "    ``test_loader``; each one is reshaped to 8x8. If the batch holds fewer\n",
    "    than 16 samples, the remaining panels are left empty.\n",
    "    \"\"\"\n",
    "    # REAL-------\n",
    "    num_x = 4\n",
    "    num_y = 4\n",
    "    x, _ = next(iter(test_loader))\n",
    "    x = x.detach().numpy()\n",
    "\n",
    "    fig, axes = plt.subplots(num_x, num_y)\n",
    "    for i, ax in enumerate(axes.flatten()):\n",
    "        ax.axis('off')\n",
    "        if i < len(x):\n",
    "            plottable_image = np.reshape(x[i], (8, 8))\n",
    "            ax.imshow(plottable_image, cmap='gray')\n",
    "\n",
    "    fig.savefig(name+'_real_images.pdf', bbox_inches='tight')\n",
    "    plt.close(fig)\n",
    "\n",
    "\n",
    "def samples_generated(name, data_loader, extra_name=''):\n",
    "    \"\"\"Save a 4x4 grid of generated images.\n",
    "\n",
    "    The model stored in ``name + '.model'`` is loaded, 16 samples are drawn\n",
    "    from it, and the figure is written to\n",
    "    ``name + '_generated_images' + extra_name + '.pdf'``.\n",
    "\n",
    "    ``data_loader`` is not used; it is kept so that existing calls keep working.\n",
    "    \"\"\"\n",
    "    # GENERATIONS-------\n",
    "    # NOTE: this unpickles a whole model object - only load files you trust.\n",
    "    model_best = torch.load(name + '.model')\n",
    "    model_best.eval()\n",
    "\n",
    "    num_x = 4\n",
    "    num_y = 4\n",
    "    # Sampling needs no gradients.\n",
    "    with torch.no_grad():\n",
    "        x = model_best.sample(num_x * num_y)\n",
    "    x = x.detach().numpy()\n",
    "\n",
    "    fig, axes = plt.subplots(num_x, num_y)\n",
    "    for i, ax in enumerate(axes.flatten()):\n",
    "        plottable_image = np.reshape(x[i], (8, 8))\n",
    "        ax.imshow(plottable_image, cmap='gray')\n",
    "        ax.axis('off')\n",
    "\n",
    "    fig.savefig(name + '_generated_images' + extra_name + '.pdf', bbox_inches='tight')\n",
    "    plt.close(fig)\n",
    "\n",
    "def plot_curve(name, nll_val, file_name='_nll_val_curve.pdf', color='b-'):\n",
    "    \"\"\"Plot ``nll_val`` against the epoch index and save it to ``name + file_name``.\n",
    "\n",
    "    ``color`` is a matplotlib format string. The y-axis is always labelled\n",
    "    'nll', whatever quantity is passed in.\n",
    "    \"\"\"\n",
    "    epochs = np.arange(len(nll_val))\n",
    "    plt.plot(epochs, nll_val, color, linewidth='3')\n",
    "    plt.xlabel('epochs')\n",
    "    plt.ylabel('nll')\n",
    "\n",
    "    target = name + file_name\n",
    "    plt.savefig(target, bbox_inches='tight')\n",
    "    plt.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def training(name, max_patience, num_epochs, model, optimizer, training_loader, val_loader):\n",
    "    \"\"\"Train ``model`` with early stopping on the validation loss.\n",
    "\n",
    "    After every epoch the model is evaluated on ``val_loader``. Whenever the\n",
    "    validation loss improves (and always after the first epoch) the model is\n",
    "    saved to ``name + '.model'``; from the second epoch on, an improvement\n",
    "    also writes a figure with generated samples. Training stops after\n",
    "    ``num_epochs`` epochs or once the loss has not improved for more than\n",
    "    ``max_patience`` consecutive epochs.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(nll_val, error_val, gen_val)`` of numpy arrays with the\n",
    "        per-epoch validation loss, classification error and generative NLL.\n",
    "    \"\"\"\n",
    "    nll_val = []\n",
    "    gen_val = []\n",
    "    error_val = []\n",
    "    best_nll = float('inf')\n",
    "    patience = 0\n",
    "\n",
    "    # Main loop\n",
    "    for e in range(num_epochs):\n",
    "        # TRAINING\n",
    "        model.train()\n",
    "        for batch, targets in training_loader:\n",
    "            loss = model(batch, targets)\n",
    "\n",
    "            optimizer.zero_grad()\n",
    "            # The graph is rebuilt for every batch, so there is no need to retain it.\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "        # Validation\n",
    "        loss_e, error_e, gen_e = evaluation(val_loader, model_best=model, epoch=e)\n",
    "        nll_val.append(loss_e)  # save for plotting\n",
    "        gen_val.append(gen_e)  # save for plotting\n",
    "        error_val.append(error_e)  # save for plotting\n",
    "\n",
    "        # The first epoch always produces a checkpoint, whatever its loss.\n",
    "        if e == 0 or loss_e < best_nll:\n",
    "            print('saved!')\n",
    "            torch.save(model, name + '.model')\n",
    "            best_nll = loss_e\n",
    "            patience = 0\n",
    "\n",
    "            if e > 0:\n",
    "                samples_generated(name, val_loader, extra_name=\"_epoch_\" + str(e))\n",
    "        else:\n",
    "            patience = patience + 1\n",
    "\n",
    "        if patience > max_patience:\n",
    "            break\n",
    "\n",
    "    nll_val = np.asarray(nll_val)\n",
    "    error_val = np.asarray(error_val)\n",
    "    gen_val = np.asarray(gen_val)\n",
    "\n",
    "    return nll_val, error_val, gen_val"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Initialize dataloaders"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fixed, position-based splits of the Digits dataset.\n",
    "train_data = Digits(mode='train')\n",
    "val_data = Digits(mode='val')\n",
    "test_data = Digits(mode='test')\n",
    "\n",
    "# Only the training batches are shuffled; validation and test keep their order.\n",
    "training_loader = DataLoader(train_data, batch_size=64, shuffle=True)\n",
    "val_loader = DataLoader(val_data, batch_size=64, shuffle=False)\n",
    "test_loader = DataLoader(test_data, batch_size=64, shuffle=False)\n",
    "\n",
    "# Directory and file-name prefix for the saved model, figures and metrics.\n",
    "result_dir = 'results/'\n",
    "# exist_ok avoids the check-then-create race of a separate os.path.exists test.\n",
    "os.makedirs(result_dir, exist_ok=True)\n",
    "name = 'hybrididf'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Hyperparams"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "D = 64   # input dimension\n",
    "M = 256  # the number of neurons in scale (s) and translation (t) nets\n",
    "K = 10 # the number of labels\n",
    "\n",
    "alpha = 1./D\n",
    "\n",
    "lr = 1e-3 # learning rate\n",
    "num_epochs = 1000 # max. number of epochs\n",
    "max_patience = 20 # an early stopping is used, if training doesn't improve for longer than 20 epochs, it is stopped"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Initialize IDF"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The number of invertible transformations\n",
    "num_flows = 2\n",
    "\n",
    "# This variable defines whether we use: \n",
    "#   1 - the classic coupling layer proposed in (Hogeboom et al., 2019)\n",
    "#   4 - the general invertible transformation in (Tomczak, 2020) with 4 partitions\n",
    "idf_git = 1\n",
    "\n",
    "\n",
    "def make_translation_net(in_features, out_features):\n",
    "    \"\"\"Build a new fully-connected translation net with two hidden layers of M units.\"\"\"\n",
    "    return nn.Sequential(nn.Linear(in_features, M), nn.LeakyReLU(),\n",
    "                         nn.Linear(M, M), nn.LeakyReLU(),\n",
    "                         nn.Linear(M, out_features))\n",
    "\n",
    "\n",
    "if idf_git == 1:\n",
    "    # One half of the features predicts the shift of the other half.\n",
    "    nett = lambda: make_translation_net(D // 2, D // 2)\n",
    "    netts = [nett]\n",
    "\n",
    "elif idf_git == 4:\n",
    "    # Three quarters of the features predict the shift of the remaining quarter.\n",
    "    nett_a = lambda: make_translation_net(3 * (D // 4), D // 4)\n",
    "    nett_b = lambda: make_translation_net(3 * (D // 4), D // 4)\n",
    "    nett_c = lambda: make_translation_net(3 * (D // 4), D // 4)\n",
    "    nett_d = lambda: make_translation_net(3 * (D // 4), D // 4)\n",
    "\n",
    "    netts = [nett_a, nett_b, nett_c, nett_d]\n",
    "\n",
    "else:\n",
    "    raise ValueError('idf_git must be either 1 or 4.')\n",
    "\n",
    "# The classifier outputs probabilities (softmax), which is what HybridIDF expects.\n",
    "classnet = nn.Sequential(nn.Linear(D, M), nn.LeakyReLU(),\n",
    "                         nn.Linear(M, M), nn.LeakyReLU(),\n",
    "                         nn.Linear(M, K),\n",
    "                         nn.Softmax(dim=1))\n",
    "\n",
    "# Init IDF\n",
    "model = HybridIDF(netts, classnet, num_flows, D=D, alpha=alpha)\n",
    "# Print the summary (like in Keras); the dummy input follows D instead of a hard-coded 64.\n",
    "print(summary(model, torch.zeros(1, D), torch.tensor([0]), show_input=False, show_hierarchical=False))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Let's play! Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# OPTIMIZER\n",
    "# Adamax over every trainable parameter of the model.\n",
    "optimizer = torch.optim.Adamax(\n",
    "    [param for param in model.parameters() if param.requires_grad],\n",
    "    lr=lr,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training procedure: returns the per-epoch validation curves.\n",
    "nll_val, error_val, gen_val = training(\n",
    "    name=result_dir + name,\n",
    "    max_patience=max_patience,\n",
    "    num_epochs=num_epochs,\n",
    "    model=model,\n",
    "    optimizer=optimizer,\n",
    "    training_loader=training_loader,\n",
    "    val_loader=val_loader,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the best saved model on the test set and store the metrics.\n",
    "test_loss, test_error, test_gen = evaluation(name=result_dir + name, test_loader=test_loader)\n",
    "# The with-block closes the file even if writing fails.\n",
    "with open(result_dir + name + '_test_loss.txt', \"w\") as f:\n",
    "    f.write('NLL: ' + str(test_loss) + '\\nCA: ' + str(test_error) + '\\nGEN NLL: ' + str(test_gen))\n",
    "\n",
    "samples_real(result_dir + name, test_loader)\n",
    "\n",
    "# Validation curves collected during training.\n",
    "plot_curve(result_dir + name, nll_val)\n",
    "plot_curve(result_dir + name, error_val, file_name='_ca_val_curve.pdf', color='r-')\n",
    "plot_curve(result_dir + name, gen_val, file_name='_gen_val_curve.pdf', color='g-')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
